{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using a new dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial we show how you can use a dataset not present in the library.\n",
    "\n",
    "This particular example uses the ENZIMES dataset, uses a simplicial lifting to create simplicial complexes, and trains the SCN2 model. We train the model using the appropriate training and validation datasets, and finally test it on the test dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### <font color='289C4E'>Table of contents<font><a class='anchor' id='top'></a>\n",
    "&emsp;[1. Imports](##sec1)\n",
    "\n",
    "&emsp;[2. Configurations and utilities](##sec2)\n",
    "\n",
    "&emsp;[3. Loading the data](##sec3)\n",
    "\n",
    "&emsp;[4. Model initialization](##sec4)\n",
    "\n",
    "&emsp;[5. Training](##sec5)\n",
    "\n",
    "&emsp;[6. Testing the model](##sec6)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Imports <a class=\"anchor\" id=\"sec1\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import lightning as pl\n",
    "# Hydra related imports\n",
    "from omegaconf import OmegaConf\n",
    "# Dataset related imports\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from topobench.dataloader.dataloader import TBDataloader\n",
    "from topobench.data.preprocessor import PreProcessor\n",
    "# Model related imports\n",
    "from topobench.model.model import TBModel\n",
    "from topomodelx.nn.simplicial.scn2 import SCN2\n",
    "from topobench.nn.wrappers.simplicial import SCNWrapper\n",
    "from topobench.nn.encoders import AllCellFeatureEncoder\n",
    "from topobench.nn.readouts import PropagateSignalDown\n",
    "# Optimization related imports\n",
    "from topobench.loss.loss import TBLoss\n",
    "from topobench.optimizer import TBOptimizer\n",
    "from topobench.evaluator.evaluator import TBEvaluator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Configurations and utilities <a class=\"anchor\" id=\"sec2\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Configurations can be specified using yaml files or directly specified in your code like in this example."
   ]
  },
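  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you prefer YAML, the same dictionaries can live in files and be loaded with `OmegaConf.load`. A minimal sketch, assuming a hypothetical `configs/transform.yaml` next to the notebook:\n",
    "\n",
    "```yaml\n",
    "clique_lifting:\n",
    "  transform_type: lifting\n",
    "  transform_name: SimplicialCliqueLifting\n",
    "  complex_dim: 3\n",
    "```\n",
    "\n",
    "```python\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "# Load the lifting configuration from the (hypothetical) YAML file;\n",
    "# the resulting DictConfig is interchangeable with the dict-based one below.\n",
    "transform_config = OmegaConf.load(\"configs/transform.yaml\")\n",
    "```"
   ]
  },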
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "transform_config = { \"clique_lifting\":\n",
    "    {\"transform_type\": \"lifting\",\n",
    "    \"transform_name\": \"SimplicialCliqueLifting\",\n",
    "    \"complex_dim\": 3,}\n",
    "}\n",
    "\n",
    "split_config = {\n",
    "    \"learning_setting\": \"inductive\",\n",
    "    \"split_type\": \"random\",\n",
    "    \"data_seed\": 0,\n",
    "    \"data_split_dir\": \"./data/ENZYMES/splits/\",\n",
    "    \"train_prop\": 0.5,\n",
    "}\n",
    "\n",
    "in_channels = 3\n",
    "out_channels = 6\n",
    "dim_hidden = 16\n",
    "\n",
    "wrapper_config = {\n",
    "    \"out_channels\": dim_hidden,\n",
    "    \"num_cell_dimensions\": 3,\n",
    "}\n",
    "\n",
    "readout_config = {\n",
    "    \"readout_name\": \"PropagateSignalDown\",\n",
    "    \"num_cell_dimensions\": 1,\n",
    "    \"hidden_dim\": dim_hidden,\n",
    "    \"out_channels\": out_channels,\n",
    "    \"task_level\": \"graph\",\n",
    "    \"pooling_type\": \"sum\",\n",
    "}\n",
    "\n",
    "loss_config = {\n",
    "    \"dataset_loss\": \n",
    "        {\n",
    "            \"task\": \"classification\", \n",
    "            \"loss_type\": \"cross_entropy\"\n",
    "        }\n",
    "}\n",
    "\n",
    "evaluator_config = {\"task\": \"classification\",\n",
    "                    \"num_classes\": out_channels,\n",
    "                    \"metrics\": [\"accuracy\", \"precision\", \"recall\"]}\n",
    "\n",
    "optimizer_config = {\"optimizer_id\": \"Adam\",\n",
    "                    \"parameters\":\n",
    "                        {\"lr\": 0.001,\"weight_decay\": 0.0005}\n",
    "                    }\n",
    "\n",
    "transform_config = OmegaConf.create(transform_config)\n",
    "split_config = OmegaConf.create(split_config)\n",
    "readout_config = OmegaConf.create(readout_config)\n",
    "loss_config = OmegaConf.create(loss_config)\n",
    "evaluator_config = OmegaConf.create(evaluator_config)\n",
    "optimizer_config = OmegaConf.create(optimizer_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def wrapper(**factory_kwargs):\n",
    "    def factory(backbone):\n",
    "        return SCNWrapper(backbone, **factory_kwargs)\n",
    "    return factory"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Loading the data <a class=\"anchor\" id=\"sec3\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example we use the ENZYMES dataset. It is a graph dataset and we use the clique lifting to transform the graphs into simplicial complexes. We invite you to check out the README of the [repository](https://github.com/pyt-team/TopoBenchX) to learn more about the various liftings offered."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Transform parameters are the same, using existing data_dir: ./data/ENZYMES/clique_lifting/3206123057\n"
     ]
    }
   ],
   "source": [
    "dataset_dir = \"./data/ENZYMES/\"\n",
    "dataset = TUDataset(root=dataset_dir, name=\"ENZYMES\")\n",
    "\n",
    "preprocessor = PreProcessor(dataset, dataset_dir, transform_config)\n",
    "dataset_train, dataset_val, dataset_test = preprocessor.load_dataset_splits(split_config)\n",
    "datamodule = TBDataloader(dataset_train, dataset_val, dataset_test, batch_size=32)"
   ]
  },
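  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the lifting produced, we can inspect a single training sample. This is just a quick sanity check, assuming the splits index like standard PyTorch Geometric datasets; the exact attribute names (features and incidence matrices per rank) depend on the lifting used.\n",
    "\n",
    "```python\n",
    "# Print one lifted sample to list the attributes added by the simplicial lifting\n",
    "print(dataset_train[0])\n",
    "```"
   ]
  },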
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Model initialization <a class=\"anchor\" id=\"sec4\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can create the backbone by instantiating the SCN2 model from TopoModelX. Then the `SCNWrapper` and the `TBModel` take care of the rest."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "backbone = SCN2(in_channels_0=dim_hidden, in_channels_1=dim_hidden, in_channels_2=dim_hidden)\n",
    "wrapper = wrapper(**wrapper_config)\n",
    "\n",
    "readout = PropagateSignalDown(**readout_config)\n",
    "loss = TBLoss(**loss_config)\n",
    "feature_encoder = AllCellFeatureEncoder(in_channels=[in_channels, in_channels, in_channels], out_channels=dim_hidden)\n",
    "\n",
    "evaluator = TBEvaluator(**evaluator_config)\n",
    "optimizer = TBOptimizer(**optimizer_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = TBModel(backbone=backbone,\n",
    "                 backbone_wrapper=wrapper,\n",
    "                 readout=readout,\n",
    "                 loss=loss,\n",
    "                 feature_encoder=feature_encoder,\n",
    "                 evaluator=evaluator,\n",
    "                 optimizer=optimizer,\n",
    "                 compile=False,)"
   ]
  },
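  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`TBModel` is a Lightning module that stacks the feature encoder, the wrapped backbone, and the readout, so a quick look at the standard PyTorch repr confirms the assembly before training:\n",
    "\n",
    "```python\n",
    "# Inspect the assembled module (encoder -> wrapped backbone -> readout)\n",
    "print(model)\n",
    "```"
   ]
  },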
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Training <a class=\"anchor\" id=\"sec5\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can use the `lightning` trainer to train the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:44: Attribute 'backbone_wrapper' removed from hparams because it cannot be pickled. You can suppress this warning by setting `self.save_hyperparameters(ignore=['backbone_wrapper'])`.\n",
      "\n",
      "  | Name            | Type                  | Params | Mode \n",
      "------------------------------------------------------------------\n",
      "0 | feature_encoder | AllCellFeatureEncoder | 1.2 K  | train\n",
      "1 | backbone        | SCNWrapper            | 1.6 K  | train\n",
      "2 | readout         | PropagateSignalDown   | 102    | train\n",
      "3 | val_acc_best    | MeanMetric            | 0      | train\n",
      "------------------------------------------------------------------\n",
      "2.9 K     Trainable params\n",
      "0         Non-trainable params\n",
      "2.9 K     Total params\n",
      "0.012     Total estimated model params size (MB)\n",
      "36        Modules in train mode\n",
      "0         Modules in eval mode\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassAccuracy was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.\n",
      "  warnings.warn(*args, **kwargs)\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassPrecision was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.\n",
      "  warnings.warn(*args, **kwargs)\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The ``compute`` method of metric MulticlassRecall was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.\n",
      "  warnings.warn(*args, **kwargs)\n",
      "/home/levtel/projects/dev/TopoBench/topobench/nn/wrappers/simplicial/scn_wrapper.py:75: UserWarning: Sparse CSR tensor support is in beta state. If you miss a functionality in the sparse tensor support, please submit a feature request to https://github.com/pytorch/pytorch/issues. (Triggered internally at ../aten/src/ATen/SparseCsrTensorImpl.cpp:53.)\n",
      "  normalized_matrix = diag_matrix @ (matrix @ diag_matrix)\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.\n",
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (10) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
      "`Trainer.fit` stopped: `max_epochs=5` reached.\n"
     ]
    }
   ],
   "source": [
    "#%%capture\n",
    "# Increase the number of epochs to get better results\n",
    "trainer = pl.Trainer(max_epochs=5, accelerator=\"cpu\", enable_progress_bar=False)\n",
    "\n",
    "trainer.fit(model, datamodule)\n",
    "train_metrics = trainer.callback_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      Training metrics\n",
      " --------------------------\n",
      "train/accuracy:       0.1867\n",
      "train/precision:      0.1796\n",
      "train/recall:         0.1835\n",
      "val/loss:             3.2280\n",
      "val/accuracy:         0.1600\n",
      "val/precision:        0.1735\n",
      "val/recall:           0.1554\n",
      "train/loss:           3.2763\n"
     ]
    }
   ],
   "source": [
    "print('      Training metrics\\n', '-'*26)\n",
    "for key in train_metrics:\n",
    "    print('{:<21s} {:>5.4f}'.format(key+':', train_metrics[key].item()))"
   ]
  },
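  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After only 5 epochs the metrics sit near chance level (ENZYMES has 6 classes, so random guessing yields roughly 0.17 accuracy); training for more epochs improves them. For longer runs you will likely also want checkpointing. A minimal sketch using Lightning's standard `ModelCheckpoint` callback, monitoring the `val/loss` key logged above:\n",
    "\n",
    "```python\n",
    "from lightning.pytorch.callbacks import ModelCheckpoint\n",
    "\n",
    "# Keep the best model according to validation loss (adjust max_epochs as needed)\n",
    "checkpoint_callback = ModelCheckpoint(monitor=\"val/loss\", mode=\"min\", save_top_k=1)\n",
    "trainer = pl.Trainer(max_epochs=100, accelerator=\"cpu\",\n",
    "                     callbacks=[checkpoint_callback], enable_progress_bar=False)\n",
    "```"
   ]
  },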
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Testing the model <a class=\"anchor\" id=\"sec6\"></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we can test the model and obtain the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/levtel/miniconda3/envs/topobench/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test/accuracy       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.18000000715255737    </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test/loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">     3.625048875808716     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">      test/precision       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.1994038224220276     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test/recall        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.1821674406528473     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test/accuracy      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.18000000715255737   \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test/loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m    3.625048875808716    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m     test/precision      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.1994038224220276    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test/recall       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.1821674406528473    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "trainer.test(model, datamodule)\n",
    "test_metrics = trainer.callback_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "      Testing metrics\n",
      " -------------------------\n",
      "test/loss:           3.6250\n",
      "test/accuracy:       0.1800\n",
      "test/precision:      0.1994\n",
      "test/recall:         0.1822\n"
     ]
    }
   ],
   "source": [
    "print('      Testing metrics\\n', '-'*25)\n",
    "for key in test_metrics:\n",
    "    print('{:<20s} {:>5.4f}'.format(key+':', test_metrics[key].item()))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "topobench",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
